# UltiSnips snippets for Python.
# See also: :help UltiSnips-snippet-options

global !p
import re
import px

snippets_helper = px.__import__('snippets_helper')


def toCamelCase(s):
    """Title-case `s` and drop every non-alphanumeric character.

    e.g. "hello_world" -> "HelloWorld"
    """
    titled = s.title()
    kept_chars = filter(str.isalnum, titled)
    return ''.join(kept_chars)

def current_lnum():
    """Return the cursor's line number in the current window (int, 1-indexed)."""
    row, _col = vim.current.window.cursor
    return row

def collect_lsp_lint_codes(linter_name):
    """Collect diagnostic codes reported by `linter_name` near the cursor.

    Finds the first line at or below the cursor that has a diagnostic whose
    `source` equals `linter_name`, and gathers the codes of ALL matching
    diagnostics on that line.

    Args:
        linter_name: The diagnostic `source` to match (e.g. "pylint").

    Returns:
        A comma-separated string of the unique codes in sorted order,
        or an empty string if there is no matching diagnostic.
    """
    lnum = current_lnum()

    # Note that on dynamic buffer change (e.g. inserting new line),
    # the diagnostics table will not have been updated yet and thus have some offset
    diagnostics = vim.eval('luaeval("vim.diagnostic.get(0)")')

    # Gather (line, code) pairs of matching diagnostics at or below the cursor.
    # No assumption is made about the order in which diagnostics are listed.
    candidates = []
    for d in diagnostics:
        # Not every diagnostic carries a `source`; those can never match.
        if d.get('source') != linter_name:
            continue
        line = int(d['lnum']) + 1   # vim.diagnostic uses 0-index; convert to 1-index.
        if line >= lnum:
            candidates.append((line, d.get('code')))

    if not candidates:
        return ""

    # The target is the closest such line; keep every code reported on it.
    target_lnum = min(line for line, _ in candidates)
    codes = {
        str(code)   # codes may arrive as non-strings; join() needs str
        for line, code in candidates
        if line == target_lnum and code is not None
    }

    # Sort so that the expansion is deterministic.
    return ",".join(sorted(codes))

endglobal


################################################################################
# General
################################################################################

snippet TODO  "# TODO"
# TODO
endsnippet
snippet utf8
# -*- coding: utf-8 -*-
endsnippet
snippet future
from __future__ import annotations
endsnippet


################################################################################
# Class
################################################################################

snippet init
def __init__(self, ${1}):
	`!p
import ast
try:
	T = ast.parse("def __init__({}): pass".format(t[1]))
	args = [a.arg for a in T.body[0].args.args]
	line = lambda arg: "self._{arg} = {arg}".format(arg=arg)
	snip.rv = snip.mkline("super().__init__()\n")
	if args:
		snip.shift(1)  # no idea why I have to do this...
		for arg in args:
			snip.rv += snip.mkline(line(arg) + "\n")
	else:
		snip.rv = "pass"
except SyntaxError:  # incomplete?
	snip.rv = "# (SyntaxError)"
`
endsnippet

snippet self._
self._${1} = $1
endsnippet
snippet @property
@property
def ${1}(self):
	return self._$1
endsnippet
snippet repr
def __repr__(self):
	return ${1:str(self)}
endsnippet
# Regex trigger: matches "dataclass" with an optional leading "@".
snippet "@?dataclass" "@dataclasses.dataclass" r
@dataclasses.dataclass
class ${1}:
	pass
endsnippet


################################################################################
# Functional Programming
################################################################################

snippet wraps  "functools.wraps"
@functools.wraps(${1:fn})
def _wrapped(*args, **kwargs):
	return $1(*args, **kwargs)
endsnippet


################################################################################
# Linting
################################################################################

snippet "^(\s*)pylint"  "'pylint: disable-next=...' marker" r
`!p snip.rv = match.group(1)`# pylint: disable-next=${1:`!p snip.rv = collect_lsp_lint_codes("pylint")`}
endsnippet
priority -1
snippet "(#\s*)?pylint"  "'pylint: disable=...' marker" r
# pylint: disable=${1:`!p snip.rv = collect_lsp_lint_codes("pylint")`}
endsnippet
priority 0

snippet "#\s*type"  "'type: ignore' marker (mypy, pyright)" r
# type: ignore
endsnippet
snippet "ignore" "'type: ignore' marker (mypy, pyright)" r
# type: ignore
endsnippet
snippet "type:?" "'type: ignore' marker (mypy, pyright)" r
# type: ignore
endsnippet
snippet "yapf"  "yapf: disable/enable" r
# yapf: disable
endsnippet


################################################################################
# Debugging
################################################################################

snippet embed
from IPython import embed; embed(colors="neutral")  # XXX DEBUG  # yapf: disable
endsnippet
snippet pt "ptipython's embed"
import ptpython.ipython; ptpython.ipython.embed(configure=ptpython.repl.run_config)  # XXX DEBUG  # yapf: disable
endsnippet
snippet pdb
import pdb; pdb.set_trace()  # XXX DEBUG  # yapf: disable
endsnippet
snippet ipdb
import ipdb; ipdb.set_trace(cond=True)  # XXX DEBUG  # yapf: disable
endsnippet
snippet pudb
import pudb; pudb.set_trace(paused=True)  # XXX DEBUG  # yapf: disable
endsnippet
snippet nest_asyncio "nest_asyncio.apply"
import nest_asyncio; nest_asyncio.apply()  # yapf: disable
endsnippet

################################################################################
# Common Snippets and Boilerplates
################################################################################

snippet main
def main():
	${1:pass}

if __name__ == '__main__':
	main()
endsnippet
snippet path "script path (using os.path)"
__PATH__ = os.path.abspath(os.path.dirname(__file__))
endsnippet
snippet PATH "script path (using pathlib)"
__PATH__ = Path(__file__).resolve().parent
endsnippet

# Add a typecheck if-raise statement.
# TODO: auto-retrieve type annotation from the current python context.
snippet typecheck
if not isinstance(${1}, ${2:object}):
	raise TypeError("\`$1\` must be a {}, but given {}".format($2, type($1)))
endsnippet
# A TypeError expression whose message reports the expected type ($2)
# and the actual type of the offending value ($1).
snippet TypeError
TypeError("\`$1\` must be a {}, but given {}".format(${2:object}, type($1)))
endsnippet

snippet rangecheck
if not ${3:0} <= ${1} < ${2:n}:
	raise IndexError("Index out of range [{}, {}), but given {}".format($3, $2, $1))
endsnippet
snippet print_exception
traceback.print_exception(*sys.exc_info())
endsnippet
snippet pprint
from pprint import pprint
endsnippet
snippet cpprint
from prettyprinter import cpprint
endsnippet
snippet filterwarnings
warnings.filterwarnings("ignore", category=${1:DeprecationWarning}, message="$2")
endsnippet
snippet argparse
import argparse
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('${1:-o}', ${2:'--outfile',} default=${3:None}, type=${4:str})
args = parser.parse_args()
endsnippet
snippet jsondump
json.dumps(${1:object}, sort_keys=True, indent=4, separators=(',', ': '))
endsnippet
snippet easydict
from easydict import EasyDict as edict
endsnippet
snippet osp
import os.path as osp
endsnippet

# For competitive programming: redirect stdin/stdout to files.
# Uses open(), since the file() builtin exists only in Python 2.
snippet freopen
sys.stdin = open('${1:input.txt}', 'r')
sys.stdout = open('${2:output.txt}', 'w')
rl = sys.stdin.readline
endsnippet

# rich
snippet rich.console
from rich.console import Console
console = Console(markup=False)
rich_print = console.log
endsnippet
snippet rich.traceback
import rich.traceback; rich.traceback.install(width=None, show_locals=True)
endsnippet

# temp files
snippet tmpdir
tempfile.mkdtemp(suffix=None)
endsnippet

# misc utilities
snippet redirect_stdout "capture stdout (useful for testing)"
with contextlib.redirect_stdout(io.StringIO()) as f:
    ${1:pass}
stdout = f.getvalue()
$0
endsnippet
snippet strftime
time.strftime(${1:'%Y%m%d-%H%M%S'})
endsnippet
snippet timer
from contextlib import contextmanager

@contextmanager
def timer(name=""):
	import time
	_start = time.time()
	yield
	_end = time.time()
	print("[%s] Elapsed time : %.3f sec" % (name, _end - _start))
endsnippet

################################################################################
# Logging
################################################################################
# f-strings
snippet f"
f"{$1}"
endsnippet
snippet f=
f"{$1=}"
endsnippet
snippet f"=
f"{$1=}"
endsnippet
# print and logging
snippet p
print(f"{$1=}")
endsnippet
snippet log
log.info($1)
endsnippet
snippet logw
log.warning($1)
endsnippet
snippet loge
log.error($1)
endsnippet

snippet loguru
from loguru import logger
endsnippet
snippet loguru.configure
logger.configure(handlers=[])
LOG_FORMAT = (   # see loguru._defaults.LOG_FORMAT
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
    "<level>{level: <8}</level> | "
    #"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
    "<level>{message}</level>"
)
logger.add(sink=sys.stdout, format=LOG_FORMAT)
endsnippet

snippet logging
import logging
formatter = logging.Formatter('[%(levelname)s %(asctime)s] %(message)s')   # colorlog?
ch = logging.StreamHandler()
ch.setFormatter(formatter)
ch.setLevel(logging.DEBUG)

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
log.handlers = []       # No duplicated handlers
log.propagate = False   # workaround for duplicated logs in ipython
log.addHandler(ch)
endsnippet

snippet colorlog
from colorlog import ColoredFormatter   # pip install colorlog
formatter = ColoredFormatter(
    "%(log_color)s[%(asctime)s] %(message)s",
    datefmt=None, reset=True, style='%',
    log_colors={
        'DEBUG':    'cyan',
        'INFO':     'white,bold',
        'INFOV':    'cyan,bold',
        'WARNING':  'yellow',
        'ERROR':    'red,bold',
        'CRITICAL': 'red,bg_white',
    })
endsnippet


################################################################################
# Test
################################################################################
post_jump "snippets_helper.snip_expand(snip, 1, jump_forward=True)"
snippet pytest
import sys
import pytest


class Test${2:`!p snip.rv = toCamelCase(re.sub("_test$", "", snip.basename))`}:
	def test_hello(self):
		assert False, "A test case."


pytest.main$1
endsnippet

snippet pytest.main
if __name__ == '__main__':
	sys.exit(pytest.main(["-s", "-v"] + sys.argv))
endsnippet
snippet @pytest.mark.parametrize
@pytest.mark.parametrize((${1:"arg0", "arg1"}), [
	(${2:"val0", "val1"}),
])
endsnippet
snippet pytest.param
pytest.param(${1}, id=None, marks=${3:pytest.mark.xfail})
endsnippet
snippet pytest.fixture
@pytest.fixture
def ${1:fixture_name}($2):
	return $3
endsnippet
snippet pytest.xfail
@pytest.mark.xfail(reason="${1:xfail}")
endsnippet
snippet pytest.filterwarnings
@pytest.mark.filterwarnings("${1:ignore:DeprecationWarning}")
endsnippet
snippet pytest.raises
with pytest.raises(${1:ValueError}, match='.*${2:some regex}.*'):
	${3:pass}
endsnippet

snippet unittest
class Test${1:Example}(unittest.TestCase):
	def testHello(self):
		self.assertFalse("Implement me!")
endsnippet

snippet assert_array_equal
np.testing.assert_array_equal($1, $2)
endsnippet


################################################################################
# Profiling
################################################################################

snippet line_profiler
from importlib import import_module  # isort: skip
from line_profiler import LineProfiler  # isort: skip
lp = LineProfiler()
lp.add_module(import_module("$1"))
lp.runcall(${2:main})
lp.print_stats(stream=import_module('sys').stderr, output_unit=1e-3, \
               stripzeros=True, sort=True, rich=True, summarize=True)
endsnippet

################################################################################
# Ipython
################################################################################

# Import from the public IPython.utils.process module rather than the
# private, POSIX-only IPython.utils._process_posix.
snippet ipython.system "Import system from ipython (same as !shell execute magic)" b
from IPython.utils.process import system, getoutput
endsnippet


################################################################################
# ML & Data Science
################################################################################

# matplotlib
# Attach a colorbar on a dedicated axes placed to the right of $1.
# Uses Figure.colorbar because the mpl_toolkits.axes_grid1.colorbar module
# is no longer available in recent Matplotlib releases.
snippet mpl_colorbar
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
cax = make_axes_locatable(${1:ax}).append_axes("right", size="5%", pad=0.05)
$1.figure.colorbar(${2:im}, cax=cax)
endsnippet
